// STL
#include <algorithm>                   // std::sort()
#include <cmath>                       // std::nextafter(), std::nexttoward()
#include <cstddef>                     // std::size_t
#include <functional>                  // std::greater<>
#include <limits>                      // std::numeric_limits<>
#include <stdexcept>                   // std::invalid_argument
#include <utility>                     // std::pair<>, std::make_pair()
#include <vector>                      // std::vector<>

// jScience
#include "jutility.hpp"                // sort_indices_desc()

// stats++
#include "statsxx/postprocess/ROC.hpp" // ROC_pt, divide_and_calc_ROCs(), ROC_pt_at_thresh()
#include "statsxx/statistics.hpp"      // statistics::mean(), ::stddev()


//
// DESC: Threshold averaging of ROC curves.
//
// INPUT:
//
//     L         : set of test examples
//     f(i)      : probabilistic classifier's estimate that example i is positive
//     ntrhesh   : number of threshold divisions
//                 NOTE: Number of points is (nthresh+1).
//     nROCs     : number of ROCs to calculate
//
// OUTPUT:
//
//     ROC       : averaged ROC curve
//     stddevROC : standard deviation of ROC curve in both x (TPR) and y (FPR); score not used
//
// NOTE: This implementation (including notation) is that of Algorithm 4 in:
//
//     T. Fawcett "An introduction to ROC analysis" Pattern Recognition Letters 27, 861--874 (2006)
//
// TODO: NOTE: One thing that cannot be done in this calculation is output each ROC curve individually.
//
inline std::pair<
                 std::vector<ROC_pt>, // ROC
                 std::vector<ROC_pt>  // stddevROC
                 > calc_ROC_curve_thresh_avg(
                                             const std::vector<int>              &L,
                                             const std::vector<double>           &f,
                                             // -----
                                             const int                            nthresh,
                                             const int                            nROCs
                                             )
{
    std::vector<ROC_pt> ROC;
    std::vector<ROC_pt> stddevROC;

    // DIVIDE AND CALCULATE nROCs ROC CURVES
    //
    // NOTE: Storing all of the ROC curves can be extremely memory intensive.
    //
    // TODO: NOTE: ... Perhaps think of an alternative way to do this.
    //
    std::vector<std::vector<ROC_pt>> ROCs = divide_and_calc_ROCs(
                                                                 L,
                                                                 f,
                                                                 // -----
                                                                 nROCs
                                                                 );

    // SORT *ALL* ROC SCORES BY DECREASING f VALUE
    std::vector<double> T;
    for( auto idx : sort_indices_desc(f) )
    {
        T.push_back(f[idx]);
    }

    // NOTE: In order to capture the (0,0) point correctly, we need a threshold just above that which contains any points.
    double T0 = std::nexttoward(T[0], (T[0] + 1.));

    T.insert(T.begin(), T0);

    // LOOP OVER ALL THRESHOLDS

    // break T into nthresh thresholds
    // NOTE: There are T.size() total points, but (T.size()-1) thresholds.
    int npts = (T.size()-1)/nthresh;
    int rem  = (T.size()-1)%nthresh;

    int tidx = 0;

    for( int thresh = 0; thresh <= nthresh; ++thresh )
    {
        std::vector<double> FPR;
        std::vector<double> TPR;

        for( int i = 0; i < nROCs; ++i )
        {
            ROC_pt p = ROC_pt_at_thresh(
                                        ROCs[i],
                                        T[tidx]
                                        );

            FPR.push_back( p.FPR );
            TPR.push_back( p.TPR );
        }

        ROC_pt tmp_ROC;
        tmp_ROC.FPR   = statistics::mean( FPR );
        tmp_ROC.TPR   = statistics::mean( TPR );
        tmp_ROC.score = T[tidx];
        ROC.push_back(tmp_ROC);

        ROC_pt tmp_stddevROC;
        tmp_stddevROC.FPR   = statistics::stddev( FPR );
        tmp_stddevROC.TPR   = statistics::stddev( TPR );
        stddevROC.push_back(tmp_stddevROC);

        // NOTE: This adds any additional points (due to unequal number of points per threshold) to the first sets.
        if( thresh < rem )
        {
            tidx += (npts + 1);
        }
        else
        {
            tidx += npts;
        }
    }

    return std::make_pair(
                          ROC,
                          stddevROC
                          );
}
